{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms, models\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.autograd import Variable\n",
    "import copy\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadimg(path = None):\n",
    "    \"\"\"Load an image file and return it as a batched tensor.\n",
    "\n",
    "    The file is decoded with PIL, forced to RGB, passed through the\n",
    "    module-level `transform` pipeline and finally given a leading\n",
    "    dimension of size 1 via `unsqueeze(0)`.\n",
    "\n",
    "    Args:\n",
    "        path: location of the image file; must not be None.\n",
    "\n",
    "    Returns:\n",
    "        The output of `transform` for the image with one extra leading\n",
    "        dimension.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `path` is None.\n",
    "    \"\"\"\n",
    "    if path is None:\n",
    "        raise ValueError(\"loadimg() needs the path of an image file\")\n",
    "    # The context manager closes the underlying file handle. convert()\n",
    "    # returns a fully loaded copy, so `rgb` stays usable afterwards, and it\n",
    "    # makes grayscale / palette / RGBA files come out with three channels.\n",
    "    with Image.open(path) as raw:\n",
    "        rgb = raw.convert(\"RGB\")\n",
    "    img = transform(rgb)\n",
    "    return img.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `transforms.Scale` was renamed to `transforms.Resize`; prefer the new name and\n",
    "# fall back to the old one so the cell runs on either torchvision generation.\n",
    "_resize = transforms.Resize if hasattr(transforms, \"Resize\") else transforms.Scale\n",
    "\n",
    "# Pipeline used by loadimg(): resize to 224x224, then convert to a tensor.\n",
    "transform = transforms.Compose([_resize([224,224]),\n",
    "                                transforms.ToTensor()])\n",
    "\n",
    "content_img = Variable(loadimg(\"images/4.jpg\"))\n",
    "style_img = Variable(loadimg(\"images/1.jpg\"))\n",
    "\n",
    "# Move the images to the GPU only when one is present, so the cell does not\n",
    "# crash on CPU-only machines.\n",
    "if torch.cuda.is_available():\n",
    "    content_img = content_img.cuda()\n",
    "    style_img = style_img.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Content_loss(torch.nn.Module):\n",
    "    def __init__(self, weight, target):\n",
    "        super(Content_loss, self).__init__()\n",
    "        self.weight = weight\n",
    "        self.target = target.detach()*weight\n",
    "        self.loss_fn = torch.nn.MSELoss()\n",
    "        \n",
    "    def forward(self, input):\n",
    "        self.loss = self.loss_fn(input*self.weight, self.target)\n",
    "        return input\n",
    "    \n",
    "    def backward(self):\n",
    "        self.loss.backward(retain_graph = True)\n",
    "        return self.loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Gram_matrix(torch.nn.Module):\n",
    "    \n",
    "    def forward(self, input):\n",
    "        a,b,c,d = input.size()\n",
    "        feature = input.view(a*b, c*d)\n",
    "        gram = torch.mm(feature, feature.t())\n",
    "        return gram.div(a*b*c*d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Style_loss(torch.nn.Module):\n",
    "    \"\"\"Pass-through layer that records a weighted MSE between Gram matrices.\n",
    "\n",
    "    `forward` hands its input back untouched, keeps the weighted Gram matrix\n",
    "    of the input on `self.Gram` and the loss on `self.loss`; `backward`\n",
    "    back-propagates that stored loss and returns it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weight, target):\n",
    "        \"\"\"Store `weight` and a detached, pre-scaled copy of `target`.\"\"\"\n",
    "        super().__init__()\n",
    "        self.weight = weight\n",
    "        # detach() makes the target a constant for autograd.\n",
    "        frozen_target = target.detach()\n",
    "        self.target = frozen_target*weight\n",
    "        self.loss_fn = torch.nn.MSELoss()\n",
    "        self.gram = Gram_matrix()\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Compute and remember the style loss for `input`, then return `input`.\"\"\"\n",
    "        # The Gram matrix is built from a clone, not from `input` itself.\n",
    "        # NOTE(review): presumably this protects the saved activations from\n",
    "        # in-place layers further down the network - confirm.\n",
    "        weighted_gram = self.gram(input.clone())\n",
    "        weighted_gram.mul_(self.weight)\n",
    "        self.Gram = weighted_gram\n",
    "        self.loss = self.loss_fn(weighted_gram, self.target)\n",
    "        return input\n",
    "\n",
    "    def backward(self):\n",
    "        \"\"\"Back-propagate the last stored loss (graph is retained) and return it.\"\"\"\n",
    "        recorded = self.loss\n",
    "        recorded.backward(retain_graph = True)\n",
    "        return recorded\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_gpu = torch.cuda.is_available()\n",
    "\n",
    "# Only the convolutional feature extractor of VGG-16 is needed.\n",
    "cnn = models.vgg16(pretrained=True).features\n",
    "\n",
    "# The optimiser updates the image only, never the network, so switch off\n",
    "# gradient tracking for the weights; otherwise every backward pass computes\n",
    "# and accumulates weight gradients that nothing ever reads or resets.\n",
    "for weight_tensor in cnn.parameters():\n",
    "    weight_tensor.requires_grad = False\n",
    "\n",
    "if use_gpu:\n",
    "    cnn = cnn.cuda()\n",
    "\n",
    "# Names of the conv layers after which a loss module is inserted.\n",
    "content_layer = [\"Conv_3\"]\n",
    "style_layer = [\"Conv_1\", \"Conv_2\", \"Conv_3\", \"Conv_4\"]\n",
    "\n",
    "# Filled in by the model-building cell.\n",
    "content_losses = []\n",
    "style_losses = []\n",
    "\n",
    "# NOTE: the spelling `conten_weight` is kept because the model-building cell\n",
    "# refers to the variable by this exact name.\n",
    "conten_weight = 1\n",
    "style_weight = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build `new_model`: the first layers of the network with pass-through loss\n",
    "# modules spliced in right after the selected conv layers.\n",
    "new_model = torch.nn.Sequential()\n",
    "model = copy.deepcopy(cnn)\n",
    "gram = Gram_matrix()\n",
    "\n",
    "# Start from empty lists so that re-running this cell does not leave loss\n",
    "# modules of a previous run behind (they would be back-propagated as well).\n",
    "content_losses = []\n",
    "style_losses = []\n",
    "\n",
    "if use_gpu:\n",
    "    new_model = new_model.cuda()\n",
    "    gram = gram.cuda()\n",
    "\n",
    "# `index` numbers the conv/ReLU pairs; it advances after every ReLU.\n",
    "index = 1\n",
    "for layer in list(model)[:8]:\n",
    "    if isinstance(layer, torch.nn.Conv2d):\n",
    "        name = \"Conv_\"+str(index)\n",
    "        new_model.add_module(name, layer)\n",
    "\n",
    "        if name in content_layer:\n",
    "            # Target = activations of the content image at this depth.\n",
    "            target = new_model(content_img).clone()\n",
    "            content_loss = Content_loss(conten_weight, target)\n",
    "            new_model.add_module(\"content_loss_\"+str(index), content_loss)\n",
    "            content_losses.append(content_loss)\n",
    "\n",
    "        if name in style_layer:\n",
    "            # Target = Gram matrix of the style image's activations here.\n",
    "            target = new_model(style_img).clone()\n",
    "            target = gram(target)\n",
    "            style_loss = Style_loss(style_weight, target)\n",
    "            new_model.add_module(\"style_loss_\"+str(index), style_loss)\n",
    "            style_losses.append(style_loss)\n",
    "\n",
    "    elif isinstance(layer, torch.nn.ReLU):\n",
    "        name = \"Relu_\"+str(index)\n",
    "        new_model.add_module(name, layer)\n",
    "        index = index+1\n",
    "\n",
    "    elif isinstance(layer, torch.nn.MaxPool2d):\n",
    "        name = \"MaxPool_\"+str(index)\n",
    "        new_model.add_module(name, layer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iteration budget; the counter lives in a one-element list so that code\n",
    "# inside a nested function can increment it in place.\n",
    "epoch_n = 300\n",
    "epoch = [0]\n",
    "\n",
    "# The image being optimised starts out as a copy of the content image and\n",
    "# is the only parameter handed to L-BFGS.\n",
    "input_img = content_img.clone()\n",
    "parameter = torch.nn.Parameter(input_img.data)\n",
    "optimizer = torch.optim.LBFGS([parameter])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closure():\n",
    "    \"\"\"Run one loss evaluation for L-BFGS and return the total loss.\n",
    "\n",
    "    Clamps the image to [0, 1], runs it through `new_model` so every loss\n",
    "    module stores a fresh loss, back-propagates each of them, bumps the\n",
    "    evaluation counter `epoch[0]` and prints the scores every 50 evaluations.\n",
    "    \"\"\"\n",
    "    optimizer.zero_grad()\n",
    "    parameter.data.clamp_(0,1)\n",
    "    new_model(parameter)\n",
    "\n",
    "    style_score = 0\n",
    "    content_score = 0\n",
    "    for sl in style_losses:\n",
    "        style_score += sl.backward()\n",
    "\n",
    "    for cl in content_losses:\n",
    "        content_score += cl.backward()\n",
    "\n",
    "    epoch[0] += 1\n",
    "\n",
    "    if epoch[0] % 50 == 0:\n",
    "        # float() reads the scalar out of a 0-dim tensor (`.data[0]` raises\n",
    "        # on current PyTorch) and also copes with a plain 0 when a loss list\n",
    "        # is empty.\n",
    "        print('Epoch:{} Style Loss: {:4f} Content Loss:{:4f}'.format(epoch[0],float(style_score), float(content_score)))\n",
    "\n",
    "    return style_score+content_score\n",
    "\n",
    "\n",
    "# The closure is defined once, outside the loop. The counter advances once\n",
    "# per closure call, and a single step() may call it several times.\n",
    "while epoch[0] <= epoch_n:\n",
    "    optimizer.step(closure)\n",
    "\n",
    "# The last update happens after the closure's clamp, so bring the final image\n",
    "# back into the valid [0, 1] range once more.\n",
    "parameter.data.clamp_(0,1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
